import argparse

import torch
import numpy as np

from geotransformer.utils.data import registration_collate_fn_stack_mode
from geotransformer.utils.torch import to_cuda, release_cuda
from geotransformer.utils.open3d import make_open3d_point_cloud, get_color, draw_geometries
from geotransformer.utils.open3d import *
from geotransformer.utils.registration import compute_registration_error

from config import make_cfg
from model import create_model
from skimage import color


def make_parser():
    """Build the command-line parser for this registration demo.

    Every option is mandatory: the two point-cloud ``.npy`` files, the
    ground-truth transformation file, and the model checkpoint.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser()
    required_options = (
        ("--src_file", "src point cloud numpy file"),
        ("--ref_file", "ref point cloud numpy file"),
        ("--gt_file", "ground-truth transformation file"),
        ("--weights", "model weights file"),
    )
    for flag, description in required_options:
        cli.add_argument(flag, required=True, help=description)
    return cli


def _split_points_and_colors(data):
    """Split a loaded point-cloud array into xyz coordinates and RGB colors.

    Args:
        data: 2-D array whose first three columns are x, y, z. Columns 3-5,
            when all three are present, are treated as RGB.

    Returns:
        (points, colors): ``points`` is ``data[:, :3]``. ``colors`` is
        ``data[:, 3:6]`` when the array has at least six columns, otherwise an
        all-ones array of shape (N, 3). Colors whose maximum exceeds 1.0 are
        assumed to be 0-255 and are divided by 255.
    """
    points = data[:, :3]
    # Require all three RGB channels: a 4- or 5-column input (e.g. xyz + intensity)
    # would otherwise produce an (N, 1) / (N, 2) "color" array that is not valid RGB.
    if data.shape[1] >= 6:
        colors = data[:, 3:6]
    else:
        colors = np.ones((points.shape[0], 3))
    # Heuristic rescale of 8-bit colors; skip empty clouds, where max() would raise.
    if colors.size > 0 and colors.max() > 1.0:
        colors = colors / 255.0
    return points, colors


def load_data(args):
    """Load the src/ref point clouds (and optional ground truth) into a sample dict.

    Args:
        args: parsed CLI namespace providing ``src_file``, ``ref_file`` and
            ``gt_file`` paths readable by ``np.load``.

    Returns:
        dict of float32 arrays with keys ``ref_points``/``src_points`` (xyz),
        ``ref_feats``/``src_feats`` (all-ones, one column per point),
        ``ref_color``/``src_color`` (RGB), ``ref_color_hsv``/``src_color_hsv``
        (the RGB colors converted with ``skimage.color.rgb2hsv``), ``hsv``
        (ref HSV rows followed by src HSV rows) and, when ``gt_file`` is set,
        ``transform`` (the array loaded from that file).
    """
    src_points, src_color = _split_points_and_colors(np.load(args.src_file))
    ref_points, ref_color = _split_points_and_colors(np.load(args.ref_file))

    ref_color_hsv = color.rgb2hsv(ref_color)
    src_color_hsv = color.rgb2hsv(src_color)
    # Stacked ref-first, matching the ref-then-src ordering of the points.
    hsv = np.concatenate([ref_color_hsv, src_color_hsv], axis=0)

    data_dict = {
        "ref_points": ref_points.astype(np.float32),
        "src_points": src_points.astype(np.float32),
        # Constant single-channel input features, one per point.
        "ref_feats": np.ones((ref_points.shape[0], 1), dtype=np.float32),
        "src_feats": np.ones((src_points.shape[0], 1), dtype=np.float32),
        "ref_color": ref_color.astype(np.float32),
        "src_color": src_color.astype(np.float32),
        "ref_color_hsv": ref_color_hsv.astype(np.float32),
        "src_color_hsv": src_color_hsv.astype(np.float32),
        "hsv": hsv.astype(np.float32),
    }

    if args.gt_file:
        transform = np.load(args.gt_file).astype(np.float32)
        data_dict["transform"] = transform

    return data_dict


def _make_pcd(points, colors=None, uniform_color=None):
    """Build an Open3D point cloud with estimated normals.

    Args:
        points: point coordinates passed to ``make_open3d_point_cloud``.
        colors: optional per-point RGB values (presumably (N, 3) in [0, 1]).
        uniform_color: optional single color painted on every point; takes
            precedence over ``colors`` when both are given.

    Returns:
        The constructed point cloud.
    """
    pcd = make_open3d_point_cloud(points)
    pcd.estimate_normals()
    if uniform_color is not None:
        pcd.paint_uniform_color(uniform_color)
    elif colors is not None:
        pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd


def main():
    """Register one src/ref point-cloud pair with the model, visualize it, and report errors.

    Steps: load and collate the pair named on the command line, run the model on
    the GPU in inference mode, then show (1) both clouds in uniform colors,
    (2) both clouds in their own colors, (3) correspondence lines against a
    sideways-shifted copy of the source when the model outputs correspondences,
    and (4) the source aligned by the estimated transform. Finally prints the
    rotation/translation error against the ground-truth transform.
    """
    parser = make_parser()
    args = parser.parse_args()

    cfg = make_cfg()

    # Prepare a single-sample batch.
    data_dict = load_data(args)
    neighbor_limits = [38, 36, 36, 38]  # default setting in 3DMatch
    data_dict = registration_collate_fn_stack_mode(
        [data_dict], cfg.backbone.num_stages, cfg.backbone.init_voxel_size, cfg.backbone.init_radius, neighbor_limits
    )

    # Build the model and load the checkpoint.
    model = create_model(cfg).cuda()
    # map_location="cpu" lets a checkpoint saved on any device load; load_state_dict
    # then copies the tensors into the CUDA model.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(args.weights, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    # Inference mode: otherwise train-only behavior (dropout, batch-norm updates,
    # any `self.training` branches in the model) would be active.
    model.eval()

    # Prediction; no autograd graph is needed.
    data_dict = to_cuda(data_dict)
    with torch.no_grad():
        output_dict = model(data_dict)
    data_dict = release_cuda(data_dict)
    output_dict = release_cuda(output_dict)

    ref_points = output_dict["ref_points"]
    src_points = output_dict["src_points"]
    estimated_transform = output_dict["estimated_transform"]
    transform = data_dict["transform"]

    # Per-point colors carried through the collate step from load_data.
    # NOTE(review): assumes the model's output ref/src points correspond 1:1 to
    # these color rows (same count and order) -- confirm.
    ref_color = data_dict["ref_color"]
    src_color = data_dict["src_color"]

    # View 1: uniform colors (ref yellow, src blue).
    draw_geometries(
        _make_pcd(ref_points, uniform_color=get_color("custom_yellow")),
        _make_pcd(src_points, uniform_color=get_color("custom_blue")),
    )

    # View 2: each cloud with its own per-point colors.
    ref_pcd = _make_pcd(ref_points, colors=ref_color)
    src_pcd = _make_pcd(src_points, colors=src_color)
    draw_geometries(ref_pcd, src_pcd)

    # View 3 (optional): correspondence lines.
    ref_corr_points = output_dict.get("ref_corr_points", None)
    src_corr_points = output_dict.get("src_corr_points", None)
    if ref_corr_points is not None and src_corr_points is not None and ref_corr_points.shape[0] > 0:
        # Shift a *separate* source cloud and the src line endpoints by the same
        # offset so the lines are visible; src_pcd itself stays in place for the
        # aligned view below, and the model outputs are not mutated.
        offset = np.array([3.0, 0.0, 0.0])
        shifted_src_pcd = _make_pcd(src_points + offset, colors=src_color)
        corr_lines = make_open3d_corr_lines(ref_corr_points, src_corr_points + offset, label="pos")
        draw_geometries(ref_pcd, shifted_src_pcd, corr_lines)

    # View 4: source aligned by the estimated transform (applied to the
    # unshifted source cloud; Open3D's transform() modifies it in place).
    src_pcd.transform(estimated_transform)
    draw_geometries(ref_pcd, src_pcd)

    # Registration error against the ground truth.
    rre, rte = compute_registration_error(transform, estimated_transform)
    print(f"RRE(deg): {rre:.3f}, RTE(m): {rte:.3f}")


if __name__ == "__main__":
    main()
